{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719a374f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Copyright 2022 Google LLC, licensed under the Apache License, Version 2.0 (the \"License\")\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c36e88eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve\n",
    "from scipy.interpolate import interp1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58087d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_log_likelihoods(vis_dist, mode):\n",
    "  \"\"\"Loads log likelihoods from probs.pkl files.\n",
    "  \n",
    "  Args:\n",
    "    vis_dist: Visible dist of the model\n",
    "    mode: \"grayscale\" or \"color\"\n",
    "  Returns:\n",
    "    A nested dictionary containing the log likelihoods\n",
    "  \"\"\"\n",
    "  \n",
    "  if mode == 'grayscale':\n",
    "    datasets = [\n",
    "      'mnist',\n",
    "      'fashion_mnist',\n",
    "      'emnist/letters',\n",
    "      'sign_lang',\n",
    "    ]\n",
    "    nf = 32\n",
    "    cs_hist = 'adhisteq'\n",
    "  else:\n",
    "    datasets = [\n",
    "      'svhn_cropped',\n",
    "      'cifar10',\n",
    "      'celeb_a',\n",
    "      'gtsrb',\n",
    "      'compcars',\n",
    "    ]\n",
    "    nf = 64\n",
    "    cs_hist = 'histeq'\n",
    "\n",
    "  log_probs = defaultdict(lambda: defaultdict(dict))\n",
    "  for id_data in datasets:\n",
    "    for norm in [None, 'pctile-5', cs_hist]:\n",
    "      with open(\n",
    "          (f'vae_ood/models/{vis_dist}/'\n",
    "           f'{id_data.replace(\"/\", \"_\")}-{norm}-zdim_20-lr_0.0005-bs_64-nf_{nf}/'\n",
    "           'probs.pkl'),\n",
    "          'rb') as f:\n",
    "        d = pickle.load(f)\n",
    "      for i, ood_data in enumerate(datasets + ['noise']):\n",
    "        log_probs[f'{id_data}-{norm}'][f'{ood_data}-{norm}']['orig_probs'] = d['orig_probs'][ood_data]\n",
    "        log_probs[f'{id_data}-{norm}'][f'{ood_data}-{norm}']['corr_probs'] = d['corr_probs'][ood_data]\n",
    "  return log_probs`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8915ffa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ood_detection_metrics(ood_lls, id_lls):\n",
    "  \"\"\"Scores how well log likelihoods separate ID from OOD samples.\n",
    "\n",
    "  ID samples are the positive class (label 1) and OOD samples the negative\n",
    "  class (label 0); the log likelihood itself is used as the detection score.\n",
    "\n",
    "  Args:\n",
    "    ood_lls: log likelihoods of the OOD samples\n",
    "    id_lls: log likelihoods of the ID samples\n",
    "  Returns:\n",
    "    A tuple (auroc, auprc, fpr_at_80) of fractions (not percentages), where\n",
    "    fpr_at_80 is the false positive rate linearly interpolated at a true\n",
    "    positive rate of 0.8.\n",
    "  \"\"\"\n",
    "  # Labels are derived from the very arrays being scored so that labels and\n",
    "  # scores always have matching lengths.\n",
    "  labels_concat = np.concatenate(\n",
    "      [np.zeros_like(ood_lls), np.ones_like(id_lls)])\n",
    "  lls_concat = np.concatenate([ood_lls, id_lls])\n",
    "  roc = roc_auc_score(labels_concat, lls_concat)\n",
    "  prc = average_precision_score(labels_concat, lls_concat)\n",
    "  fpr, tpr, _ = roc_curve(\n",
    "      labels_concat, lls_concat, pos_label=1, drop_intermediate=False)\n",
    "  # First ROC point whose TPR exceeds 0.8; together with the point before it,\n",
    "  # it brackets TPR = 0.8. This relies on the curve starting at or below 0.8\n",
    "  # so that ind >= 1 (roc_curve is expected to start at TPR 0).\n",
    "  ind = np.argmax(tpr > 0.8)\n",
    "  x = np.array((tpr[ind-1], tpr[ind]))\n",
    "  y = np.array((fpr[ind-1], fpr[ind]))\n",
    "  fpr_at_80 = interp1d(x, y)(0.8)\n",
    "  return roc, prc, fpr_at_80\n",
    "\n",
    "\n",
    "def get_metrics(log_probs, max_samples=10000):\n",
    "  \"\"\"Computes AUROC, AUPRC and FPR@80 metrics using probs.pkl files.\n",
    "\n",
    "  Args:\n",
    "    log_probs: original and corrected log likelihoods for all ID-OOD\n",
    "               pairs as returned by get_log_likelihoods()\n",
    "    max_samples: at most this many leading samples of each dataset are used\n",
    "  Returns:\n",
    "    A nested dictionary metrics[id_data][ood_data][name] containing the\n",
    "    metrics as percentages, where name is one of 'orig_roc', 'orig_prc',\n",
    "    'orig_fpr', 'corr_roc', 'corr_prc', 'corr_fpr'.\n",
    "  \"\"\"\n",
    "\n",
    "  metrics = defaultdict(lambda: defaultdict(dict))\n",
    "  for id_data in log_probs:\n",
    "    for ood_data in log_probs[id_data]:\n",
    "      # The same computation runs once on the original and once on the\n",
    "      # corrected log likelihoods.\n",
    "      for kind in ('orig', 'corr'):\n",
    "        roc, prc, fpr_at_80 = _ood_detection_metrics(\n",
    "            log_probs[id_data][ood_data][f'{kind}_probs'][:max_samples],\n",
    "            log_probs[id_data][id_data][f'{kind}_probs'][:max_samples])\n",
    "        metrics[id_data][ood_data][f'{kind}_roc'] = roc*100\n",
    "        metrics[id_data][ood_data][f'{kind}_prc'] = prc*100\n",
    "        metrics[id_data][ood_data][f'{kind}_fpr'] = fpr_at_80*100\n",
    "  return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8a9e1ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_metrics(metrics):\n",
    "  \"\"\"Returns key metrics in a dataframe.\n",
    "  \n",
    "  Args:\n",
    "    metrics: metrics dict returned by get_metrics()\n",
    "  Returns:\n",
    "    A dataframe containing key metrics\n",
    "  \"\"\"\n",
    "  df = pd.DataFrame(\n",
    "    columns = ['ID Data ↓ OOD Data →'] + \n",
    "    list(set(dname.split('-')[0] for dname in metrics.keys())) + ['noise'])\n",
    "  for id_data in df.columns[1:-1]:\n",
    "    df_row = {'ID Data ↓ OOD Data →': id_data}\n",
    "    for ood_data in df.columns[1:]:\n",
    "      df_row[ood_data] = [\n",
    "          int(round(metrics[f'{id_data}-None'][f'{ood_data}-None']['orig_roc'],\n",
    "                    0)),\n",
    "          int(round(metrics[f'{id_data}-pctile-5'][f'{ood_data}-pctile-5']['corr_roc'],\n",
    "                    0))\n",
    "          ]\n",
    "    df = df.append(df_row, ignore_index=True)\n",
    "  return df.set_index('ID Data ↓ OOD Data →', drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba206245",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'cont_bernoulli' visible distribution on the grayscale datasets.\n",
    "cb_grayscale_lls = get_log_likelihoods(\n",
    "    vis_dist='cont_bernoulli', mode='grayscale')\n",
    "cb_grayscale_metrics = get_metrics(log_probs=cb_grayscale_lls)\n",
    "print_metrics(metrics=cb_grayscale_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a732b60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'categorical' visible distribution on the grayscale datasets. This cell\n",
    "# previously loaded 'cont_bernoulli', duplicating the cell above.\n",
    "cat_grayscale_lls = get_log_likelihoods('categorical', 'grayscale')\n",
    "cat_grayscale_metrics = get_metrics(cat_grayscale_lls)\n",
    "print_metrics(cat_grayscale_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d77eaa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'cont_bernoulli' visible distribution on the color datasets.\n",
    "cb_color_lls = get_log_likelihoods(\n",
    "    vis_dist='cont_bernoulli', mode='color')\n",
    "cb_color_metrics = get_metrics(log_probs=cb_color_lls)\n",
    "print_metrics(metrics=cb_color_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9907f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'categorical' visible distribution on the color datasets.\n",
    "cat_color_lls = get_log_likelihoods(\n",
    "    vis_dist='categorical', mode='color')\n",
    "cat_color_metrics = get_metrics(log_probs=cat_color_lls)\n",
    "print_metrics(metrics=cat_color_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e26b5b3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
